6 Fitting Models with parsnip

Published

September 17, 2025

Modified

September 18, 2025

parsnip包是tidymodels系列中的一个R包,它为各种不同的模型提供了流畅且标准化的接口。在本章中,我们将阐述为什么通用接口有助于在实际操作中理解和构建模型,并展示如何使用parsnip包。

具体来说,我们将重点介绍如何直接使用parsnip对象进行fit()predict()操作,这可能适用于一些简单的建模问题。下一章将介绍一种更优的方法,通过将模型和预处理器组合成一个称为workflow的对象,来处理许多建模任务。

Create a Model

一旦数据被编码成适合建模算法的格式(例如数值矩阵),就可以将其用于模型构建过程中。

假设我们最初选择的是线性回归模型。这相当于指定结果数据是数值型的,并且预测变量与结果之间的关系可以用简单的斜率和截距来表示:

\[y_i = \beta_0 + \beta_1 x_{1i} + \ldots + \beta_p x_{pi}\]

可以使用多种方法来估计模型参数:

  • 普通线性回归采用传统的最小二乘法来求解模型参数。

  • 正则化线性回归在最小二乘法中加入惩罚项,通过移除预测变量和/或将其系数缩小至零来追求模型的简洁性。这可以通过贝叶斯或非贝叶斯技术来实现。

Heterogeneous Interface

在R语言中,stats包可用于第一种情况。使用函数lm()进行线性回归的语法为model <- lm(formula, data, ...),其中...表示要传递给lm()的其他参数。该函数没有单独的xy的接口,而是直接假定结果是y,预测变量是x

对于第二种情况,即使用正则化进行估计,可以通过rstanarm包来拟合贝叶斯模型:model <- stan_glm(formula, data, family = "gaussian", ...)。在这种情况下,通过...传递的其他选项将包括参数的先验分布的参数以及有关模型数值方面的具体信息。与lm()一样,只有公式接口可用。

一种流行的非贝叶斯正则化回归方法是glmnet模型(Friedman, Hastie, Tibshirani 2010)。其语法如下:model <- glmnet(x = matrix, y = vector, family = "gaussian", ...)。在这种情况下,预测变量数据必须已经格式化为数值矩阵;这里有xy的单独接口,没有公式接口。

请注意,这些接口在数据传递给模型函数的方式或其参数方面存在异质性。第一个问题是,为了在不同的包中拟合模型,数据必须以不同的方式格式化。lm()stan_glm()只有公式接口,而glmnet()没有,对于其他类型的模型,其接口可能差异更大。对于试图进行数据分析的人来说,这些差异要求他们记住每个包的语法,这可能会非常令人沮丧。

对于tidymodels而言,指定模型的方法旨在更加统一:

  • 根据模型的数学结构指定其类型(例如,线性回归、随机森林、K近邻等)。

  • 指定用于拟合模型的引擎。这通常指的是使用的软件包,如Stanglmnet。这些本身就是模型,而parsnip通过将它们用作建模引擎来提供一致的接口。

  • 必要时,声明模型的模式。模式反映了预测结果的类型。对于数值型结果,模式为回归;对于定性结果,模式为分类。如果一种模型算法只能处理一种类型的预测结果(如线性回归),则其模式已预先设定。

这些规格的构建没有参考数据。例如,对于我们概述的三个案例:

library(tidymodels)
#> ── Attaching packages ─────────────────────────────────── tidymodels 1.4.1 ──
#> ✔ broom        1.0.9     ✔ recipes      1.3.1
#> ✔ dials        1.4.2     ✔ rsample      1.3.1
#> ✔ dplyr        1.1.4     ✔ tailor       0.1.0
#> ✔ ggplot2      3.5.2     ✔ tidyr        1.3.1
#> ✔ infer        1.0.9     ✔ tune         2.0.0
#> ✔ modeldata    1.5.1     ✔ workflows    1.3.0
#> ✔ parsnip      1.3.3     ✔ workflowsets 1.1.1
#> ✔ purrr        1.1.0     ✔ yardstick    1.3.2
#> ── Conflicts ────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
tidymodels_prefer()

linear_reg() %>% set_engine("lm")
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: lm

linear_reg() %>% set_engine("glmnet")
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: glmnet

linear_reg() %>% set_engine("stan")
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: stan

Consistent Interface

一旦模型的细节确定后,就可以使用fit()函数(用于公式)或fit_xy()函数(当数据已预处理时)来进行模型估计。parsnip包允许用户不必在意底层模型的接口;即使建模包的函数只有x/y接口,你也始终可以使用公式。

translate()函数可以详细说明parsnip如何将用户的代码转换为该包的语法:

linear_reg() %>%
  set_engine("lm") %>%
  translate()
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: lm 
#> 
#> Model fit template:
#> stats::lm(formula = missing_arg(), data = missing_arg(), weights = missing_arg())

linear_reg(penalty = 1) %>%
  set_engine("glmnet") %>%
  translate()
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = 1
#> 
#> Computational engine: glmnet 
#> 
#> Model fit template:
#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
#>     family = "gaussian")

linear_reg() %>%
  set_engine("stan") %>%
  translate()
#> Linear Regression Model Specification (regression)
#> 
#> Computational engine: stan 
#> 
#> Model fit template:
#> rstanarm::stan_glm(formula = missing_arg(), data = missing_arg(), 
#>     weights = missing_arg(), family = stats::gaussian, refresh = 0)

注意,missing_arg()只是尚未提供的数据的占位符。我们为glmnet引擎提供了一个必需的penalty参数。此外,对于Stan和glmnet引擎,family参数作为默认值被自动添加。正如本节后面将展示的,这个选项是可以更改的。

让我们逐步了解如何仅将经度和纬度作为函数来预测艾姆斯数据中房屋的销售价格:

data(ames)
ames <- ames %>% mutate(Sale_Price = log10(Sale_Price))

set.seed(502)
ames_split <- initial_split(ames, prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
lm_model <-
  linear_reg() %>%
  set_engine("lm")

lm_form_fit <-
  lm_model %>%
  # Recall that Sale_Price has been pre-logged
  fit(Sale_Price ~ Longitude + Latitude, data = ames_train)

lm_xy_fit <-
  lm_model %>%
  fit_xy(
    x = ames_train %>% select(Longitude, Latitude),
    y = ames_train %>% pull(Sale_Price)
  )

lm_form_fit
#> parsnip model object
#> 
#> 
#> Call:
#> stats::lm(formula = Sale_Price ~ Longitude + Latitude, data = data)
#> 
#> Coefficients:
#> (Intercept)    Longitude     Latitude  
#>    -302.974       -2.075        2.710

lm_xy_fit
#> parsnip model object
#> 
#> 
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#> 
#> Coefficients:
#> (Intercept)    Longitude     Latitude  
#>    -302.974       -2.075        2.710

Consistent Parameters

parsnip不仅为不同的包提供了一致的模型接口,还在模型参数方面保持了一致性。拟合相同模型的不同函数往往具有不同的参数名称,这是很常见的情况。随机森林模型函数就是一个很好的例子。三个常用的参数分别是集成中的树的数量、在树的每次分裂时随机抽样的预测变量数量,以及进行分裂所需的数据点数量。对于实现该算法的三个不同的R包,这些参数如 Table 1 所示。

Table 1: Example argument names for different random forest functions.
Argument Type ranger randomForest sparklyr
# sampled predictors mtry mtry feature_subset_strategy
# trees num.trees ntree num_trees
# data points to split min.node.size nodesize min_instances_per_node

为了减轻参数指定的麻烦,parsnip在包内部和包之间使用通用的参数名称。Table 2 展示了parsnip模型在随机森林中所使用的参数。

Table 2: Random forest argument names used by parsnip.
Argument Type parsnip
# sampled predictors mtry
# trees trees
# data points to split min_n

诚然,这是又一组需要记住的参数。不过,当其他类型的模型具有相同的参数类型时,这些名称仍然适用。例如,梯度提升树集成也会创建大量基于树的模型,因此在那里也会使用treesmin_n也是如此,依此类推。

一些原始参数名称可能相当专业。例如,在glmnet模型中,为了指定要使用的正则化量,会用到希腊字母lambda。虽然这种数学符号在统计学文献中很常用,但很多人并不清楚lambda代表什么(尤其是那些使用模型结果的人)。由于这是正则化中使用的惩罚项,parsnip将参数名称标准化为penalty。同样,KNN模型中的邻居数量被称为neighbors,而不是k。我们在标准化参数名称时的经验法则是:如果从业者要将这些名称包含在图表或表格中,查看这些结果的人会理解这些名称吗?

要了解parsnip的参数名称如何对应原始名称,请使用模型的帮助文件(可通过?rand_forest获取)以及translate()函数:

rand_forest(trees = 1000, min_n = 5) %>%
  set_engine("ranger") %>%
  set_mode("regression") %>%
  translate()
#> Random Forest Model Specification (regression)
#> 
#> Main Arguments:
#>   trees = 1000
#>   min_n = 5
#> 
#> Computational engine: ranger 
#> 
#> Model fit template:
#> ranger::ranger(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
#>     num.trees = 1000, min.node.size = min_rows(~5, x), num.threads = 1, 
#>     verbose = FALSE, seed = sample.int(10^5, 1))

parsnip中的建模函数将模型参数分为两类:

  • 主要参数更为常用,且往往在不同引擎中都可使用。

  • 引擎参数要么是特定于某个引擎的,要么使用频率较低。

例如,在之前随机森林代码的转换中,参数num.threadsverboseseed是默认添加的。这些参数是随机森林模型range实现所特有的,作为主要参数是不合理的。特定于引擎的参数可以在 set_engine()中指定。例如,要让ranger::ranger()函数打印出更多关于拟合的信息:

rand_forest(trees = 1000, min_n = 5) %>%
  set_engine("ranger", verbose = TRUE) %>%
  set_mode("regression") %>%
  translate()
#> Random Forest Model Specification (regression)
#> 
#> Main Arguments:
#>   trees = 1000
#>   min_n = 5
#> 
#> Engine-Specific Arguments:
#>   verbose = TRUE
#> 
#> Computational engine: ranger 
#> 
#> Model fit template:
#> ranger::ranger(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
#>     num.trees = 1000, min.node.size = min_rows(~5, x), verbose = TRUE, 
#>     num.threads = 1, seed = sample.int(10^5, 1))

Use the Model Results

一旦模型创建并拟合完成,我们可以通过多种方式使用其结果;我们可能想要绘制、打印或以其他方式检查模型输出。parsnip模型对象中存储了多个量,包括拟合好的模型。这可以在一个名为fit的元素中找到,该元素可通过extract_fit_engine()函数返回:

lm_form_fit %>% extract_fit_engine()
#> 
#> Call:
#> stats::lm(formula = Sale_Price ~ Longitude + Latitude, data = data)
#> 
#> Coefficients:
#> (Intercept)    Longitude     Latitude  
#>    -302.974       -2.075        2.710

常规方法可应用于该对象,例如打印和绘图:

lm_form_fit %>%
  extract_fit_engine() %>%
  vcov()
#>             (Intercept)     Longitude      Latitude
#> (Intercept)  207.311311  1.5746587743 -1.4239709610
#> Longitude      1.574659  0.0165462548 -0.0005999802
#> Latitude      -1.423971 -0.0005999802  0.0325397353

永远不要将parsnip模型的fit元素传递给模型预测函数,也就是说,使用predict(lm_form_fit),而不是predict(lm_form_fit$fit)。如果数据经过了任何预处理,将会产生错误的预测结果(有时不会出现错误提示)。底层模型的预测函数并不知道在运行模型之前是否对数据进行了任何转换。有关预测的更多内容,请参见第6.3节。

base R中一些现有方法存在的一个问题是,结果的存储方式可能并非最实用。例如,针对lm对象的summary()方法可用于打印模型拟合结果,包括一个包含参数值、其不确定性估计以及p值的表格。这些特定结果也可以保存:

model_res <-
  lm_form_fit %>%
  extract_fit_engine() %>%
  summary()

# The model coefficient table is accessible via the `coef` method.
param_est <- coef(model_res)
class(param_est)
#> [1] "matrix" "array"
param_est
#>                Estimate Std. Error   t value     Pr(>|t|)
#> (Intercept) -302.973554 14.3983093 -21.04230 3.640103e-90
#> Longitude     -2.074862  0.1286322 -16.13019 1.395257e-55
#> Latitude       2.709654  0.1803877  15.02128 9.289500e-49

关于这个结果,有几点需要注意。首先,该对象是一个数值矩阵。选择这种数据结构很可能是因为所有计算结果都是数值型的,而且矩阵对象比数据框的存储效率更高。这种选择或许是在20世纪70年代末做出的,当时计算效率至关重要。其次,非数值数据(系数的标签)包含在行名中。将参数标签作为行名,这与原始S语言中的约定非常一致。

合理的下一步可能是创建参数值的可视化。要做到这一点,将参数矩阵转换为数据框是明智的。我们可以将行名作为一列添加进去,这样它们就可以在图表中使用了。然而,请注意,现有的几个矩阵列名对于普通数据框来说不是有效的R列名(例如,"Pr(>|t|)")。另一个复杂之处在于列名的一致性。对于lm对象,p值所在的列是"Pr(>|t|)",但对于其他模型,可能会使用不同的检验,因此列名会有所不同(例如,"Pr(>|z|)"),并且检验类型会编码在列名中。

虽然这些额外的数据格式化步骤并非无法克服,但它们确实是一种阻碍,尤其是因为对于不同类型的模型,这些步骤可能会有所不同。矩阵并非一种可高度复用的数据结构,主要原因是它会将数据限制为单一类型(例如数值型)。此外,将部分数据保存在维度名称中也存在问题,因为这些数据必须经过提取才能具有普遍用途。

作为一种解决方案,broom包可以将多种类型的模型对象转换为整洁的结构。例如,在线性模型上使用tidy()方法会生成:

tidy(lm_form_fit)
#> # A tibble: 3 × 5
#>   term        estimate std.error statistic  p.value
#>   <chr>          <dbl>     <dbl>     <dbl>    <dbl>
#> 1 (Intercept)  -303.      14.4       -21.0 3.64e-90
#> 2 Longitude      -2.07     0.129     -16.1 1.40e-55
#> 3 Latitude        2.71     0.180      15.0 9.29e-49

列名在各个模型中是标准化的,不包含任何额外数据(例如统计检验的类型)。以前包含在行名中的数据现在位于一个名为term的列中。tidymodels生态系统中的一个重要原则是,函数返回的值应该具有可预测性、一致性和不出意料的特点。

Make Predictions

parsnip与传统R建模函数的另一个不同之处在于predict()返回值的格式。对于预测,parsnip始终遵循以下规则:

  1. 结果始终是一个tibble( tibble 是R语言中一种数据框格式)。

  2. 该tibble的列名始终是可预测的。

  3. 该tibble中的行数始终与输入数据集的行数相同。

例如,在预测数值型数据时:

ames_test_small <- ames_test %>% slice(1:5)
predict(lm_form_fit, new_data = ames_test_small)
#> # A tibble: 5 × 1
#>   .pred
#>   <dbl>
#> 1  5.22
#> 2  5.21
#> 3  5.28
#> 4  5.27
#> 5  5.28

预测结果的行顺序始终与原始数据相同。为什么有些列名前面有圆点?一些tidyverse和tidymodels的参数及返回值包含句点。这是为了防止合并具有重复名称的数据。有些数据集包含名为pred的预测变量!

这三条规则使将预测结果与原始数据合并变得更加容易:

ames_test_small %>%
  select(Sale_Price) %>%
  bind_cols(predict(lm_form_fit, ames_test_small)) %>%
  # Add 95% prediction intervals to the results:
  bind_cols(predict(lm_form_fit, ames_test_small, type = "pred_int"))
#> # A tibble: 5 × 4
#>   Sale_Price .pred .pred_lower .pred_upper
#>        <dbl> <dbl>       <dbl>       <dbl>
#> 1       5.02  5.22        4.91        5.54
#> 2       5.39  5.21        4.90        5.53
#> 3       5.28  5.28        4.97        5.60
#> 4       5.28  5.27        4.96        5.59
#> 5       5.28  5.28        4.97        5.60

第一条规则的动机源于一些R包的预测函数会生成不同的数据类型。例如,ranger包是计算随机森林模型的出色工具。然而,它返回的不是数据框或向量形式的输出,而是一个专用对象,其中嵌入了多个值(包括预测值)。这给数据分析师在脚本中处理时又增加了一个步骤。再举一个例子,原生的glmnet模型根据模型的具体情况和数据特征,至少可以返回四种不同的预测输出类型。这些类型如 Table 3 所示。

Table 3: Different return values for glmnet prediction types.
Type of Prediction Returns a:
numeric numeric matrix
class character matrix
probability (2 classes) numeric matrix (2nd level only)
probability (3+ classes) 3D numeric array (all levels)

此外,结果的列名包含编码值,这些编码值对应于glmnet模型对象中一个名为lambda的向量。这种出色的统计方法在实际使用中可能会令人却步,因为分析师可能会遇到各种特殊情况,而这些情况需要额外编写代码才能让该方法发挥作用。

对于第二个 tidymodels 预测规则,不同预测类型的可预测列名如 Table 4 所示。

Table 4: The tidymodels mapping of prediction types and column names.
type value column name(s)
numeric .pred
class .pred_class
prob .pred_{class levels}
conf_int .pred_lower, .pred_upper
pred_int .pred_lower, .pred_upper

关于输出中行数的第三条规则至关重要。例如,如果新数据的任何行包含缺失值,输出将为这些行填充缺失结果。parsnip中对模型接口和预测类型进行标准化的一个主要优势是,当使用不同的模型时,语法是相同的。假设我们使用决策树对艾姆斯的数据进行建模。在模型规格之外,代码流程没有显著差异:

tree_model <-
  decision_tree(min_n = 2) %>%
  set_engine("rpart") %>%
  set_mode("regression")

tree_fit <-
  tree_model %>%
  fit(Sale_Price ~ Longitude + Latitude, data = ames_train)

ames_test_small %>%
  select(Sale_Price) %>%
  bind_cols(predict(tree_fit, ames_test_small))
#> # A tibble: 5 × 2
#>   Sale_Price .pred
#>        <dbl> <dbl>
#> 1       5.02  5.15
#> 2       5.39  5.15
#> 3       5.28  5.32
#> 4       5.28  5.32
#> 5       5.28  5.32

这体现了在不同模型间使数据分析流程和语法同质化的好处。它能让用户将时间花在结果和解读上,而非不得不专注于R包之间的语法差异。

parsnip-Extension Packages

parsnip包本身包含了多个模型的接口。不过,为了便于包的安装和维护,还有其他tidymodels包提供了针对其他模型集的parsnip模型定义。discrim包包含了一组称为判别分析方法(如线性或二次判别分析)的分类技术的模型定义。通过这种方式,安装parsnip所需的包依赖得以减少。所有可与parsnip配合使用的模型(涵盖CRAN上的不同包)的列表可在 https://www.tidymodels.org/find/ 找到。

Creating Model Specifications

编写许多模型规格,或者记住如何编写生成它们的代码,可能会变得很繁琐。parsnip包包含一个RStudio插件,它可以提供帮助。无论是从插件工具栏菜单中选择这个插件,还是运行以下代码:

parsnip_addin()

会在RStudio集成开发环境的查看器面板中打开一个窗口,其中包含每种模型模式的可能模型列表。这些模型可以被写入源代码面板。

模型列表包含来自CRAN上的parsnip包和parsnip扩展包中的模型。

Chapter Summary

本章介绍了parsnip包,该包使用标准语法为多个R包中的模型提供了统一接口。此接口及生成的对象具有可预测的结构。

我们接下来将要使用的用于对Ames数据进行建模的代码如下:

library(tidymodels)
data(ames)
ames <- mutate(ames, Sale_Price = log10(Sale_Price))

set.seed(502)
ames_split <- initial_split(ames, prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test  <-  testing(ames_split)

lm_model <- linear_reg() %>% set_engine("lm")
Back to top